{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: Tagger + BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "TpuauEFHAptc",
    "outputId": "5b42c411-8d60-4b2a-aa6d-bccdd06bcb64"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
     ]
    }
   ],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# NOTE(review): versions are unpinned; later cells use torchtext.data.Field and\n",
    "# pytorch_transformers.WarmupLinearSchedule - confirm which releases still provide them.\n",
    "%pip install -qqq torchtext wandb pytorch-transformers\n",
    "%pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "id": "wF4__YMPArMf",
    "outputId": "ef1fc649-3e59-4056-a744-c5451a5b9947"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "            Notebook configured with <a href=\"https://wandb.com\" target=\"_blank\">W&B</a>. You can <a href=\"https://app.wandb.ai/srush/pytorch-struct-tagging/runs/67eo4ise\" target=\"_blank\">open</a> the run page, or call <code>%%wandb</code>\n",
       "            in a cell containing your training loop to display live results.  Learn more in our <a href=\"https://docs.wandb.com/docs/integrations/jupyter.html\" target=\"_blank\">docs</a>.\n",
       "        "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "W&B Run: https://app.wandb.ai/srush/pytorch-struct-tagging/runs/67eo4ise"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torchtext\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch_struct import LinearChainCRF\n",
    "import torch_struct.data\n",
    "from pytorch_transformers import *\n",
    "config = {\"bert\": \"bert-base-cased\", \"H\": 768, \"dropout\": 0.2}\n",
    "\n",
    "# Weights & Biases logging is opt-in: set USE_WANDB = True to log to your own account.\n",
    "# Later cells call wandb.watch / wandb.log, so a small stand-in is bound to the\n",
    "# same name when logging is off; this keeps Restart & Run All working.\n",
    "USE_WANDB = False\n",
    "if USE_WANDB:\n",
    "    import wandb\n",
    "\n",
    "    wandb.init(project=\"pytorch-struct-tagging\", config=config)\n",
    "else:\n",
    "    class _PrintLogger:\n",
    "        \"\"\"Stand-in exposing the two wandb calls this notebook uses.\"\"\"\n",
    "\n",
    "        def watch(self, *args, **kwargs):\n",
    "            \"\"\"Accept and ignore a request to track a model.\"\"\"\n",
    "\n",
    "        def log(self, metrics, *args, **kwargs):\n",
    "            \"\"\"Print the metrics instead of uploading them.\"\"\"\n",
    "            print(metrics)\n",
    "\n",
    "    wandb = _PrintLogger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-rCPFdKJ-n4q"
   },
   "source": [
    "Set up data and batching."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "VWrGLRmJC89K",
    "outputId": "27d0bb7f-e922-462a-e6e7-90d33c834b5e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "error\n"
     ]
    }
   ],
   "source": [
    "# Encoder class, tokenizer class and checkpoint name used throughout the notebook.\n",
    "model_class = BertModel\n",
    "tokenizer_class = BertTokenizer\n",
    "pretrained_weights = config[\"bert\"]\n",
    "tokenizer = tokenizer_class.from_pretrained(pretrained_weights)\n",
    "\n",
    "# The word field is built around the tokenizer; the tag field adds <bos>/<eos>\n",
    "# markers and returns lengths alongside the padded tags.\n",
    "WORD = torch_struct.data.SubTokenizedField(tokenizer)\n",
    "UD_TAG = torchtext.data.Field(\n",
    "    init_token=\"<bos>\", eos_token=\"<eos>\", include_lengths=True\n",
    ")\n",
    "\n",
    "\n",
    "def _short_enough(ex):\n",
    "    \"\"\"Keep an example only if the first component of its `word` entry has fewer than 200 items.\"\"\"\n",
    "    return len(ex.word[0]) < 200\n",
    "\n",
    "\n",
    "train, val, test = torchtext.datasets.UDPOS.splits(\n",
    "    fields=((\"word\", WORD), (\"udtag\", UD_TAG), (None, None)),\n",
    "    filter_pred=_short_enough,\n",
    ")\n",
    "\n",
    "# Only the tag field gets a vocabulary here (presumably WORD relies on the\n",
    "# tokenizer's own vocabulary - confirm against SubTokenizedField).\n",
    "UD_TAG.build_vocab(train.udtag)\n",
    "# NOTE(review): 750 looks like a per-batch token budget - confirm against TokenBucket.\n",
    "train_iter = torch_struct.data.TokenBucket(train, 750)\n",
    "val_iter = torchtext.data.BucketIterator(val, batch_size=10, device=\"cuda:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ss7plawBHR_A"
   },
   "source": [
    "Set up the transformer and a simple one-layer model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jvUmy1gZAuZr"
   },
   "outputs": [],
   "source": [
    "C = len(UD_TAG.vocab)\n",
    "\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Pretrained encoder followed by per-tag scores and a tag-transition matrix.\n",
    "\n",
    "    Args:\n",
    "        hidden: input width of the scoring layer (config[\"H\"] in this notebook).\n",
    "        classes: number of tags; sizes the scoring and transition layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden, classes):\n",
    "        super().__init__()\n",
    "        self.base_model = model_class.from_pretrained(pretrained_weights)\n",
    "        # Size the layers from the `classes` argument instead of the notebook-level C.\n",
    "        # forward() reads only the .weight of these two layers.\n",
    "        self.linear = nn.Linear(hidden, classes)\n",
    "        self.transition = nn.Linear(classes, classes)\n",
    "        self.dropout = nn.Dropout(config[\"dropout\"])\n",
    "\n",
    "    def forward(self, words, mapper):\n",
    "        \"\"\"Score every adjacent pair of tags.\n",
    "\n",
    "        Args:\n",
    "            words: input passed straight to the pretrained encoder.\n",
    "            mapper: tensor contracted with the encoder output over their shared\n",
    "                second axis (presumably a sub-token-to-word alignment - confirm\n",
    "                against SubTokenizedField).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, N - 1, classes, classes) where\n",
    "            vals[b, n, i, j] = final[b, n + 1, i] + transition[i, j], and\n",
    "            final[b, 0, j] is additionally added to the n = 0 slice.\n",
    "        \"\"\"\n",
    "        out = self.dropout(self.base_model(words)[0])\n",
    "        # Follow the encoder output's device rather than assuming CUDA.\n",
    "        out = torch.einsum(\"bca,bch->bah\", mapper.float().to(out.device), out)\n",
    "        final = torch.einsum(\"bnh,ch->bnc\", out, self.linear.weight)\n",
    "        batch, N, num_tags = final.shape\n",
    "        pairwise = self.transition.weight.view(1, 1, num_tags, num_tags)\n",
    "        vals = final.view(batch, N, num_tags, 1)[:, 1:N] + pairwise\n",
    "        # Fold the scores of the first position into the first edge.\n",
    "        vals[:, 0, :, :] += final.view(batch, N, 1, num_tags)[:, 0]\n",
    "        return vals\n",
    "\n",
    "\n",
    "model = Model(config[\"H\"], C)\n",
    "# wandb is optional in this notebook; only register the model when it is available.\n",
    "if \"wandb\" in globals():\n",
    "    wandb.watch(model)\n",
    "model.cuda()\n",
    "None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WuCDtMSk-x2R"
   },
   "source": [
    "Generic training and validation loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iEv1xeZl1Qux"
   },
   "outputs": [],
   "source": [
    "def validate(itera):\n",
    "    \"\"\"Return the fraction of predicted edges that disagree with the gold tags.\n",
    "\n",
    "    Uses the notebook-level `model` and `C`. The model is switched to eval mode\n",
    "    for the pass and restored to whatever mode it was in, even if a batch raises.\n",
    "\n",
    "    Args:\n",
    "        itera: iterable of batches exposing `.word` (words, mapper, _) and\n",
    "            `.udtag` (label, lengths).\n",
    "\n",
    "    Returns:\n",
    "        incorrect edges divided by predicted edges, or NaN if `itera` is empty.\n",
    "    \"\"\"\n",
    "    incorrect_edges = 0\n",
    "    total = 0\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        # NOTE(review): not wrapped in torch.no_grad(); confirm whether dist.argmax\n",
    "        # relies on autograd before adding it.\n",
    "        for ex in itera:\n",
    "            words, mapper, _ = ex.word\n",
    "            label, lengths = ex.udtag\n",
    "            dist = LinearChainCRF(model(words.cuda(), mapper), lengths=lengths)\n",
    "            argmax = dist.argmax\n",
    "            gold = LinearChainCRF.struct.to_parts(\n",
    "                label.transpose(0, 1), C, lengths=lengths\n",
    "            ).type_as(argmax)\n",
    "            # Presumably a wrong tag differs from gold in two cells after the sum,\n",
    "            # hence the division by 2 - confirm against torch-struct's edge encoding.\n",
    "            incorrect_edges += (argmax.sum(-1) - gold.sum(-1)).abs().sum() / 2.0\n",
    "            total += argmax.sum()\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    if not torch.is_tensor(total):\n",
    "        # The iterator yielded no batches, so there is nothing to divide by.\n",
    "        return float(\"nan\")\n",
    "    return incorrect_edges / total\n",
    "    \n",
    "def train(train_iter, val_iter, model, lr=1e-4, warmup_steps=20, t_total=2500):\n",
    "    \"\"\"Fine-tune `model` as a linear-chain CRF tagger, reporting every 100 steps.\n",
    "\n",
    "    Args:\n",
    "        train_iter: iterable of batches exposing `.word` and `.udtag`.\n",
    "        val_iter: iterable handed to `validate` at each reporting step.\n",
    "        model: module called as model(words, mapper) to produce log-potentials.\n",
    "        lr: learning rate passed to AdamW.\n",
    "        warmup_steps: warm-up steps passed to WarmupLinearSchedule.\n",
    "        t_total: total number of steps the schedule is configured for.\n",
    "\n",
    "    Note:\n",
    "        `validate` reads the notebook-level `model`, not this argument.\n",
    "        The loop runs for as long as `train_iter` yields batches.\n",
    "        NOTE(review): confirm how the schedule behaves past `t_total` steps.\n",
    "    \"\"\"\n",
    "    opt = AdamW(model.parameters(), lr=lr, eps=1e-8)\n",
    "    scheduler = WarmupLinearSchedule(opt, warmup_steps=warmup_steps, t_total=t_total)\n",
    "    # wandb is optional in this notebook; fall back to printing when it is absent.\n",
    "    use_wandb = \"wandb\" in globals()\n",
    "\n",
    "    model.train()\n",
    "    losses = []\n",
    "    for i, ex in enumerate(train_iter):\n",
    "        opt.zero_grad()\n",
    "        words, mapper, _ = ex.word\n",
    "        label, lengths = ex.udtag\n",
    "\n",
    "        # Skip batches whose tag lengths exceed what the model produced scores for.\n",
    "        log_potentials = model(words.cuda(), mapper)\n",
    "        if not lengths.max() <= log_potentials.shape[1] + 1:\n",
    "            print(\"fail\")\n",
    "            continue\n",
    "\n",
    "        dist = LinearChainCRF(log_potentials, lengths=lengths.cuda())\n",
    "        labels = LinearChainCRF.struct.to_parts(\n",
    "            label.transpose(0, 1), C, lengths=lengths\n",
    "        ).type_as(dist.log_potentials)\n",
    "\n",
    "        # Maximise the summed log-probability of the gold tags.\n",
    "        log_likelihood = dist.log_prob(labels).sum()\n",
    "        (-log_likelihood).backward()\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        opt.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        losses.append(log_likelihood.detach())\n",
    "\n",
    "        if i % 100 == 10:\n",
    "            train_loss = -torch.tensor(losses).mean()\n",
    "            print(train_loss, words.shape)\n",
    "            # Logged under \"val_loss\", but this is the edge error rate from validate().\n",
    "            val_loss = validate(val_iter)\n",
    "            if use_wandb:\n",
    "                wandb.log({\"train_loss\": train_loss, \"val_loss\": val_loss})\n",
    "            else:\n",
    "                print(\"val_loss\", val_loss)\n",
    "            losses = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "id": "uvb0rbD90gJw",
    "outputId": "d5af07ff-238c-4876-8039-4662a8109b9b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1828.1356) torch.Size([14, 54])\n",
      "tensor(608.0067) torch.Size([27, 27])\n",
      "tensor(286.6919) torch.Size([27, 28])\n",
      "fail\n",
      "tensor(234.0003) torch.Size([12, 63])\n",
      "tensor(194.0490) torch.Size([44, 17])\n",
      "tensor(140.0352) torch.Size([24, 31])\n"
     ]
    }
   ],
   "source": [
    "# Run the training loop on the GPU copy of the model.\n",
    "train(train_iter=train_iter, val_iter=val_iter, model=model.cuda())"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "PyTorch Tagging.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
